import os
import os.path as osp
import glob
import gym
import numpy as np
import yaml

import diffgro
from diffgro.environments import make_env as _make_env
from diffgro.environments.variant import *
from diffgro.common.buffers import load_traj, MTTrajectoryBuffer
from diffgro.utils.config import load_config
from diffgro.utils import Parser, make_dir, print_r, print_y, print_b


def make_env(args):
    """Create the environment named by ``args.env_name`` with a pinned variant space.

    ``args.env_name`` must have the form ``"<domain>.<task>"``; both parts are
    forwarded to ``diffgro.environments.make_env``. Afterwards the variant
    space is narrowed so that every factor touched here has exactly one
    categorical value:

    * ``goal_resistance`` -- ``args.goal_resistance`` for "metaworld"; for
      "metaworld_complex" a per-object space where only the drawer receives
      ``args.goal_resistance`` and every other object gets 0.
    * ``arm_speed``, ``wind_xspeed``, ``wind_yspeed`` -- pinned to 1.0.

    Args:
        args: Namespace-like object exposing ``env_name`` and ``goal_resistance``.

    Returns:
        Tuple ``(env, domain_name, task_name)``.

    Raises:
        NotImplementedError: If the domain is not "metaworld" or "metaworld_complex".
    """
    print_r(f"<< Making Environment for {args.env_name}... >>")
    domain_name, task_name = args.env_name.split(".")
    env = _make_env(domain_name, task_name)

    if domain_name == "metaworld":
        print_b(f"Setting goal resistance to {args.goal_resistance}")
        goal_variant = Categorical(a=[args.goal_resistance])
    elif domain_name == "metaworld_complex":
        # Only the drawer gets the requested resistance; all other objects of
        # the composite task are pinned to zero.
        per_object_resistance = {
            "handle": 0,
            "button": 0,
            "drawer": args.goal_resistance,
            "lever": 0,
            "door": 0,
        }
        goal_variant = VariantSpace(
            {
                obj: Categorical(a=[value])
                for obj, value in per_object_resistance.items()
            }
        )
    else:
        raise NotImplementedError

    variants = env.variant_space.variant_config
    variants["goal_resistance"] = goal_variant
    for factor in ("arm_speed", "wind_xspeed", "wind_yspeed"):
        variants[factor] = Categorical(a=[1.0])
    return env, domain_name, task_name


def _load_task_trajectories(task_path):
    """Load every ``trajectory/*.pkl`` file under ``task_path`` with ``load_traj``.

    Paths are sorted because ``glob`` returns them in arbitrary,
    filesystem-dependent order; sorting keeps buffer contents reproducible.
    """
    traj_paths = sorted(glob.glob(osp.join(task_path, "trajectory", "*.pkl")))
    trajs = [load_traj(path) for path in traj_paths]
    print_b(f">> Number of traj is {len(trajs)}")
    return trajs


def make_buff(args, env):
    """Build a multi-task trajectory buffer from pickled trajectories on disk.

    Phase 0 loads the single task named by ``args.env_name`` from
    ``<dataset_path>/<domain>/<task>/trajectory/*.pkl``. Phase 1 loads every
    directory matching ``<dataset_path>/<domain>/*variant*`` (sorted), adding
    one buffer task per directory.

    Args:
        args: Namespace-like object with ``env_name`` ("<domain>.<task>"),
            ``phase`` (0 or 1) and ``dataset_path``.
        env: Environment whose ``observation_space`` and ``action_space`` are
            passed to ``MTTrajectoryBuffer``.

    Returns:
        ``(buff, task_list)`` where ``task_list`` holds the directory basename
        of each task added, in the order the tasks were added.

    Raises:
        NotImplementedError: If ``args.phase`` is not 0 or 1.
    """
    domain_name, task_name = args.env_name.split(".")
    if args.phase not in (0, 1):
        # Without this guard the function would fall through to the return and
        # crash with an UnboundLocalError on ``task_list``.
        raise NotImplementedError(
            f"Unsupported phase {args.phase!r}; expected 0 or 1"
        )

    # NOTE(review): meaning of the positional args 8 / False is defined by
    # MTTrajectoryBuffer -- confirm there before changing them.
    buff = MTTrajectoryBuffer(8, False, env.observation_space, env.action_space)

    if args.phase == 0:
        print_r(f"<< Making Buffer for {args.env_name}... >>")
        task_path = osp.join(args.dataset_path, domain_name, task_name)
        buff.add_task(_load_task_trajectories(task_path))
        task_list = [osp.basename(task_path)]
    else:
        print_r(f"<< Making Buffer for {domain_name}... >>")
        dataset_path = osp.join(args.dataset_path, domain_name)
        task_paths = sorted(glob.glob(osp.join(dataset_path, "*variant*")))
        if not task_paths:
            print_y(f"No '*variant*' task directories found under {dataset_path}")
        for task_path in task_paths:
            print_b(f"Adding task at {task_path}")
            buff.add_task(_load_task_trajectories(task_path))
        # osp.basename handles OS-native separators, unlike split("/").
        task_list = [osp.basename(path) for path in task_paths]
    return buff, task_list


def make_context(domain_name, task_name, multimodal=False):
    """Return the context entries configured for a task.

    Reads ``./config/contexts/<domain_name>/<task_name>.yml`` (relative to the
    current working directory), looks up the top-level ``task_name`` mapping,
    and returns its ``"multimodal"`` entry when ``multimodal`` is truthy,
    otherwise its ``"text"`` entry.

    Raises:
        OSError: If the YAML file cannot be opened.
        KeyError: If the task or the selected section is missing.
    """
    section = "multimodal" if multimodal else "text"
    config_path = f"./config/contexts/{domain_name}/{task_name}.yml"
    with open(config_path) as stream:
        document = yaml.load(stream, Loader=yaml.FullLoader)
    return document[task_name][section]


def eval_save(tot_success, save_path):
    """Print and append a summary of evaluation success rates.

    The summary reports the mean success rate (in percent), the standard
    deviation (numpy default ``ddof=0``), and the half-width of a normal
    approximation 95% confidence interval of the mean,
    ``1.96 * std / sqrt(n)``.

    Args:
        tot_success: Per-run success values. Reductions use ``axis=0`` and the
            results are formatted with ``:.3f``, so this is expected to be 1-D
            (a higher-rank input would yield arrays that cannot be formatted).
        save_path: Directory whose ``evaluation.txt`` the summary is appended to.

    Raises:
        ValueError: If ``tot_success`` is empty.
    """
    n_runs = len(tot_success)
    if n_runs == 0:
        raise ValueError("eval_save requires at least one success value")

    avg_success = np.mean(tot_success, axis=0) * 100
    std_success = np.std(tot_success, axis=0) * 100
    # Standard error of the mean is std / sqrt(n), not std / n.
    ci95 = 1.96 * std_success / np.sqrt(n_runs)

    header = "=" * 13 + " Total Performance " + "=" * 13
    # "+\\-" reproduces the literal "+\-" the original emitted, without the
    # invalid "\-" escape sequence.
    summary = (
        f"\tTotal Success Rate : {avg_success:.3f} +\\- {std_success:.3f} ({ci95:.3f})"
    )
    footer = "=" * 50

    print_r(header)
    print(summary)
    print_r(footer)

    with open(osp.join(save_path, "evaluation.txt"), "a", encoding="utf-8") as f:
        f.write("\n".join((header, summary, footer)) + "\n")
